Phase Mask Design
In this notebook we illustrate the inverse design of a phase mask, using the example from Wong et al. (2021): designing a diffractive pupil phase mask for the Toliman telescope.
To obtain high-precision centroids, we need to maximize the gradient energy of the pupil. To satisfy fabrication constraints, we need a binary mask with phases only in {0, π}.
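The binary {0, π} constraint has a simple consequence. A pixel with phase 0 or π has complex transmission exp(i·0) = +1 or exp(iπ) = −1, so the mask is purely real, up to floating-point rounding. The minimal sketch below builds a toy binary mask, forms its far-field PSF with a zero-padded FFT and scores it with a gradient-energy metric. The metric used here, the sum of squared finite-difference gradients of the normalised PSF, is an assumption for illustration. The exact objective in Wong et al. may differ, and the helper names are hypothetical.

```python
import numpy as np

def binary_phase_transmission(mask):
    """Map a {0, 1} mask to the complex transmission of a {0, pi} phase mask."""
    phase = np.pi * mask              # 0 -> 0 rad, 1 -> pi rad
    return np.exp(1j * phase)         # +1 or -1, up to rounding

def psf(pupil, pad=4):
    """Far-field intensity of a pupil field via a zero-padded FFT."""
    n = pupil.shape[0]
    padded = np.zeros((pad * n, pad * n), dtype=complex)
    padded[:n, :n] = pupil
    field = np.fft.fftshift(np.fft.fft2(padded))
    intensity = np.abs(field) ** 2
    return intensity / intensity.sum()    # normalise to unit total flux

def gradient_energy(image):
    """Hypothetical metric: sum of squared spatial gradients of an image."""
    gy, gx = np.gradient(image)
    return np.sum(gx ** 2 + gy ** 2)

# Toy 64x64 circular aperture with a random binary phase pattern.
n = 64
ys, xs = np.mgrid[:n, :n] - (n - 1) / 2
aperture = (np.hypot(xs, ys) <= n / 2).astype(float)
rng = np.random.default_rng(0)
mask = (rng.random((n, n)) > 0.5).astype(float)

t = binary_phase_transmission(mask)
flat_ge = gradient_energy(psf(aperture))      # unmasked aperture
mask_ge = gradient_energy(psf(aperture * t))  # aperture with binary phase mask
```

In the real design problem, the mask is not random: it is parameterised and optimised so that this kind of score is as large as possible.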
# Core jax
import jax
import jax.numpy as np
import jax.random as jr
# Optimisation
import equinox as eqx
import optax
# Optics
import dLux as dl
# Plotting/visualisation
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
%matplotlib inline
plt.rcParams['image.cmap'] = 'inferno'
plt.rcParams["font.family"] = "serif"
plt.rcParams['figure.dpi'] = 120
We will first generate an orthonormal basis for the pupil phases. The continuous phase built from this basis is then thresholded to {0, 1} while preserving soft edges, using the Continuous Latent Image Mask Binarization (CLIMB) algorithm from the Wong et al. paper.
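The idea of binarising while keeping soft edges can be sketched as a threshold-then-downsample step:

1. Evaluate the continuous latent image on a grid `oversample` times finer than the output.
2. Threshold it at zero.
3. Average each `oversample × oversample` block.

Pixels lying entirely on one side of the zero contour come out exactly 0 or 1. Only pixels straddling the contour get fractional values.

This is a simplified forward pass only, and it assumes a threshold at zero. A hard threshold has zero gradient almost everywhere. The actual CLIMB algorithm in Wong et al. handles differentiability in its own way, which this sketch does not reproduce. The name `soft_binarize` is hypothetical.

```python
import numpy as np

def soft_binarize(latent, oversample):
    """Threshold a fine-grid latent image at zero, then block-average.

    latent: (n * oversample, n * oversample) array on the fine grid.
    Returns an (n, n) array in [0, 1]; pixels away from the contour are 0 or 1.
    """
    fine = (latent > 0).astype(float)            # hard binary at high resolution
    n = fine.shape[0] // oversample
    blocks = fine.reshape(n, oversample, n, oversample)
    return blocks.mean(axis=(1, 3))              # fraction of "on" sub-pixels

# Toy latent: a tilted plane whose zero contour cuts diagonally across the grid.
n, oversample = 8, 3
coords = (np.arange(n * oversample) + 0.5) / oversample - n / 2
XX, YY = np.meshgrid(coords, coords)
latent = XX + 0.5 * YY

soft = soft_binarize(latent, oversample)
edge = (soft > 0) & (soft < 1)                   # pixels crossed by the contour
```

Each output value is a multiple of 1/oversample², so the "softness" of an edge pixel records how much of it lies on the π side of the contour.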
Generate the basis vectors however you like. Here we use logarithmic radial harmonics, cos(A ln r + Bθ + C), together with sines and cosines in the radius r. This code is not important: generate your favourite not-necessarily-orthonormal basis, and we will use PCA to orthonormalize it later on.
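Two properties of the logarithmic radial harmonic cos(A ln r + Bθ + C) are worth making concrete:

- **Radial scaling.** Scaling the radius by s only shifts the phase by A ln s, so the pattern is self-similar under radial scaling.
- **Rotational symmetry.** The next cell sets `Bs = nslice * np.arange(0, b+1)`, so B is always a multiple of `nslice`. A rotation by 2π/nslice then leaves the function unchanged, which gives the mask `nslice`-fold rotational symmetry.

The check below uses plain NumPy, with arbitrary illustrative values for A, B, C and s.

```python
import numpy as np

def lrhf(A, B, C, r, theta):
    """Logarithmic radial harmonic: cos(A ln r + B theta + C)."""
    return np.cos(A * np.log(r) + B * theta + C)

rng = np.random.default_rng(1)
r = rng.uniform(0.05, 1.0, size=1000)        # avoid r = 0, where ln r diverges
theta = rng.uniform(-np.pi, np.pi, size=1000)

A, C, nslice = 4.0, np.pi / 2, 3
B = nslice * 2                                # an integer multiple of nslice

# 1) Scaling r by s is equivalent to shifting C by A ln s.
s = 1.7
scaled = lrhf(A, B, C, s * r, theta)
shifted = lrhf(A, B, C + A * np.log(s), r, theta)

# 2) Rotating by 2*pi/nslice leaves the function unchanged when B % nslice == 0.
original = lrhf(A, B, C, r, theta)
rotated = lrhf(A, B, C, r, theta + 2 * np.pi / nslice)

# Counter-example: B = 4 is not a multiple of 3, so the symmetry is broken.
broken = lrhf(A, 4, C, r, theta + 2 * np.pi / nslice)
```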
# Define arrays sizes, samplings, symmetries
wf_npix = 256
oversample = 3
nslice = 3  # order of rotational symmetry: every B is a multiple of nslice
# Define basis hyper parameters
a = 10
b = 8
ith = 10
# Define coordinate grids
npix = wf_npix * oversample
c = (npix - 1) / 2.
xs = (np.arange(npix) - c) / c
XX, YY = np.meshgrid(xs, xs)
RR = np.sqrt(XX ** 2 + YY ** 2)
PHI = np.arctan2(YY, XX)
# Generate basis vectors to map over
As = np.arange(-a, a+1)
Bs = nslice * np.arange(0, b+1)
Cs = np.array([-np.pi/2, np.pi/2])
Is = np.arange(-ith, ith+1)
# Define basis functions
# (the 1e-12 offset avoids log(0) at the centre of the grid)
LRHF_fn = lambda A, B, C, RR, PHI: np.cos(A*np.log(RR + 1e-12) + B*PHI + C)
sine_fn = lambda i, RR: np.sin(i * np.pi * RR)
cose_fn = lambda i, RR: np.cos(i * np.pi * RR)
# Map over basis functions: innermost vmap over B, then A, outermost over C,
# giving an output of shape (len(Cs), len(As), len(Bs), npix, npix)
gen_LRHF_basis = jax.vmap(
    jax.vmap(
        jax.vmap(LRHF_fn, (None, 0, None, None, None)),
        (0, None, None, None, None)),
    (None, None, 0, None, None))
gen_sine_basis = jax.vmap(sine_fn, in_axes=(0, None))
gen_cose_basis = jax.vmap(cose_fn, in_axes=(0, None))
# Generate basis, collapsing the (C, A, B) axes into a single list of images
LRHF_basis = gen_LRHF_basis(As, Bs, Cs, RR, PHI).reshape(
    [len(As)*len(Bs)*len(Cs), npix, npix])
sine_basis = gen_sine_basis(Is, RR)
cose_basis = gen_cose_basis(Is, RR)
# Format shapes and combine
LRHF_flat = LRHF_basis.reshape([len(As)*len(Bs)*len(Cs), npix*npix])
sine_flat = sine_basis.reshape([len(sine_basis), npix*npix])
cose_flat = cose_basis.reshape([len(cose_basis), npix*npix])
full_basis = np.concatenate([
    LRHF_flat,
    sine_flat,
    cose_flat,
])
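The triple `jax.vmap` above evaluates every (A, B, C) combination. The same stack can be built with plain broadcasting, which makes the output axis order explicit.

This NumPy sketch uses a small grid and illustrative sizes, and `lrhf_stack` is a hypothetical helper. It assumes the axis order (C, A, B) implied by the vmap nesting. Because the stack is flattened into one list of basis vectors anyway, that order only affects how the vectors are enumerated, not which vectors are included.

```python
import numpy as np

def lrhf_stack(As, Bs, Cs, RR, PHI):
    """Broadcast cos(A ln r + B phi + C) over all (C, A, B) combinations.

    Output axes are (C, A, B, y, x), matching the vmap nesting
    (outermost over C, then A, innermost over B).
    """
    A = As[None, :, None, None, None]
    B = Bs[None, None, :, None, None]
    C = Cs[:, None, None, None, None]
    return np.cos(A * np.log(RR + 1e-12) + B * PHI + C)

# Small grid so the example runs instantly.
npix, a, b, nslice = 16, 2, 2, 3
c = (npix - 1) / 2.0
xs = (np.arange(npix) - c) / c
XX, YY = np.meshgrid(xs, xs)
RR, PHI = np.hypot(XX, YY), np.arctan2(YY, XX)

As = np.arange(-a, a + 1)
Bs = nslice * np.arange(0, b + 1)
Cs = np.array([-np.pi / 2, np.pi / 2])

stack = lrhf_stack(As, Bs, Cs, RR, PHI)
flat = stack.reshape(len(Cs) * len(As) * len(Bs), npix * npix)  # one row per basis vector
```

Broadcasting materialises the full (C, A, B, npix, npix) array just as the vmap version does, so memory use is similar. The vmap form simply keeps the per-element function readable.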
Orthonormalize the basis with PCA (Gram-Schmidt would also work if you prefer).
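The Gram-Schmidt alternative can be sketched with a QR decomposition. QR yields an orthonormal basis for the same span as classical Gram-Schmidt (up to signs) and is numerically more stable.

Treat each flattened basis image as a row, as in `full_basis`. The columns of Q from the reduced QR of the transposed matrix are then orthonormal and span the same space as the original vectors. Unlike PCA, this neither subtracts the mean nor orders the vectors by explained variance. Random data stands in for the basis here, and `orthonormalize_rows` is a hypothetical helper.

```python
import numpy as np

def orthonormalize_rows(vectors):
    """Return orthonormal rows spanning the same space as `vectors`.

    vectors: (n_vectors, n_pixels) with n_vectors <= n_pixels and full row rank.
    """
    q, _ = np.linalg.qr(vectors.T)   # reduced QR: q has shape (n_pixels, n_vectors)
    return q.T

rng = np.random.default_rng(2)
vectors = rng.normal(size=(10, 400))     # e.g. ten flattened 20x20 images
ortho = orthonormalize_rows(vectors)

gram = ortho @ ortho.T                   # the 10x10 identity, up to rounding
# Same span: projecting the originals onto the new basis reconstructs them.
reconstructed = (vectors @ ortho.T) @ ortho
```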
%%time
from sklearn.decomposition import PCA
pca = PCA().fit(full_basis)
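# Rows of components_ are orthonormal principal axes, sorted by explained variance
components = pca.components_.reshape([len(full_basis), npix, npix])
# Keep the 99 leading components
components = np.copy(components[:99, :, :])
# Prepend a constant (piston) term scaled by the mean of pca.mean_, giving 100 basis vectors
piston = np.mean(pca.mean_) * np.ones((1, npix, npix))
basis = np.concatenate([piston, components])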
CPU times: user 2min 58s, sys: 1.78 s, total: 3min
Wall time: 33.1 s
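Conceptually, PCA's components come from a singular value decomposition of the mean-centred data. The principal axes are the right singular vectors, sorted by decreasing singular value, so they are orthonormal by construction. That is why the PCA output can be used directly as an orthonormal basis.

The NumPy sketch below uses random stand-in data and shows that relationship. Scikit-learn may flip the sign of individual components relative to a raw SVD. Note also that after centring, n samples span at most n − 1 dimensions, so the last singular value is (numerically) zero.

```python
import numpy as np

rng = np.random.default_rng(3)
data = rng.normal(size=(12, 300))        # 12 flattened "images", 300 pixels each

mean = data.mean(axis=0)                 # analogous to pca.mean_
centred = data - mean
_, s, vt = np.linalg.svd(centred, full_matrices=False)
components = vt                          # analogous to pca.components_ (up to sign)

gram = components @ components.T         # identity: the rows are orthonormal
```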
Plot the 100 basis vectors:
nfigs = 100
ncols = 10
nrows = -(-nfigs // ncols)  # ceiling division: just enough rows for nfigs panels
plt.figure(figsize=(4*ncols, 4*nrows))
for i in range(nfigs):
    plt.subplot(nrows, ncols, i+1)
    plt.imshow(basis[i], cmap='seismic')
    plt.xticks([])
    plt.yticks([])
plt.tight_layout()
plt.show()